查看原文
其他

TensorFlow中的feed与fetch

gloomyfish OpenCV学堂 2019-03-29

TensorFlow中的feed与fetch

一:占位符(placeholder)与feed

当我们构建一个模型的时候,有时候我们需要在运行时候输入一些初始数据,这个时候定义模型数据输入在tensorflow中就是用placeholder(占位符)来完成。它的定义如下:

  1. def placeholder(dtype, shape=None, name=None):

其中dtype表示数据类型,shape表示维度,name表示名称。它支持单个数值与任意维度的数组输入。

1. 单个数值占位符定义

  1. a = tf.placeholder(tf.float32)

  2. b = tf.placeholder(tf.float32)

  3. c = tf.add(a, b)

当我们需要执行得到c的运行结果时候我们就需要在会话运行时候,通过feed来插入a与b对应的值,代码演示如下:

  1. with tf.Session() as sess:

  2.  result = sess.run(c, feed_dict={a:3, b:4})

  3.  print(result)

其中feed_dict就是完成了feed数据功能,feed中文有喂饭的意思,这里还是很形象的,对定义的模型来说,数据就是最好的食物,所以就通过feeddict来实现。

2. 多维数据

同样对于模型需要多维数据的情况下通过feed一样可以完成,定义二维数据的占位符,然后相加,代码如下:

  1. _x = tf.placeholder(shape=[None2], dtype=tf.float32, name="x")

  2. _y = tf.placeholder(shape=[None2], dtype=tf.float32, name="y")

  3. z = tf.add(_x, _y);

运行时候需要feed二维数组,实现如下:

  1. with tf.Session() as sess:

  2.  result = sess.run(z, feed_dict={_x:[[34], [12]], _y:[[88],[99]]})

  3.  print(result)

二:fetch用法

会话运行完成之后,如果我们想查看会话运行的结果,就需要使用fetch来实现,feed,fetch同样可以fetch单个或者多个值。

1. fetch单个值

矩阵a与b相乘之后输出结果,通过会话运行接受到值c_res这个就是fetch单个值,fetch这个单词在数据库编程中比较常见,这里称为fetch也比较形象。代码演示如下:

  1. import tensorflow as tf

  2. a = tf.Variable(tf.random_normal([33], stddev=3.0), dtype=tf.float32)

  3. b = tf.Variable(tf.random_normal([33], stddev=3.0), dtype=tf.float32)

  4. c = tf.matmul(a, b);

  5. init = tf.global_variables_initializer()

  6. with tf.Session() as sess:

  7.  sess.run(init)

  8.  c_res = sess.run(c)

  9.  print(c_res)

2. fetch多个值

还是以feed中代码为例,我们把feed与fetch整合在一起,实现feed与fetch多个值,代码演示如下:

  1. import tensorflow as tf

  2. _x = tf.placeholder(shape=[None2], dtype=tf.float32, name="x")

  3. _y = tf.placeholder(shape=[None2], dtype=tf.float32, name="y")

  4. z = tf.add(_x, _y);

  5. data = tf.random_normal([22], stddev=5.0)

  6. Y = tf.add(data, z)

  7. with tf.Session() as sess:

  8.  z_res, Y_res = sess.run((z, Y), feed_dict={_x:[[34], [12]], _y:[[88],[99]]})

  9.  print(z_res)

  10.  print(Y_res)

上述代码我们就fetch了两个值,这个就是feed与fetch的基本用法。下面我们就集合图像来通过feed与fetch实现一些图像ROI截取操作。代码演示如下:

  1. import tensorflow as tf

  2. import cv2 as cv

  3. # 通过opencv读取图像并显示

  4. src = cv.imread("D:/javaopencv/test.png")

  5. cv.imshow("input", src)

  6. _image = tf.placeholder(shape=[NoneNone3], dtype=tf.uint8, name="image")

  7. # ROI区域截取

  8. roi_image = tf.slice(_image, [401300], [180180, -1])

  9. #定义会话并执行

  10. with tf.Session() as sess:

  11.  slice = sess.run(roi_image, feed_dict={_image:src})

  12.  print(slice.shape)

  13.  cv.imshow("roi", slice)

  14.  cv.waitKey(0)

  15.  cv.destroyAllWindows()

运行结果显示: 原图:

脸部ROI截取


求木之长者,必固其根本;

欲流之远者,必浚其泉源!


关注【OpenCV学堂】

长按或者扫码下面二维码即可关注


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存